library(ggplot2)
library(dplyr)
library(readr)
library(scales)
library(tidyr)
library(patchwork)
library(tibble)
library(purrr)
library(countrycode)

# Event-level data; the columns used below are read from this file.
df <- read_csv("scad_country_final.csv")

# Variance of logged fatalities (ndeath_log) by country, highest variance first
country_var <- df %>%
  group_by(ccode) %>%
  summarise(
    var_fatalities = var(ndeath_log, na.rm = TRUE),
    n_events = n()
  ) %>%
  arrange(desc(var_fatalities))

# `n` has to be named: for tibbles the second positional argument of print()
# is `width`, so an unnamed 20 would squeeze the output to 20 characters
# instead of showing 20 rows.
print(head(country_var, 20), n = 20)

# Country code (matched against `ccode`) of the case shown in the two-panel
# figure.
selected_country <- "437"

df_plot <- df %>%
  filter(ccode == selected_country) %>%
  # Parse the dates before sorting so the rows end up in chronological order
  # even when the column was read as text.
  mutate(
    startdate = as.Date(startdate),
    cumulated_fatal_weighted_log = log1p(cumulated_fatal_weighted)
  ) %>%
  arrange(startdate)

# -------------------------------
# Look shared by both panels of the country figure.
panel_theme <- theme_minimal() +
  theme(
    legend.position = "top", legend.justification = "left",
    text = element_text(size = 14),
    plot.title = element_text(face = "bold", size = 16, hjust = 0.5),
    axis.title.x = element_text(size = 15, margin = margin(t = 15)),
    axis.title.y = element_text(size = 15, margin = margin(r = 15)),
    axis.text = element_text(size = 15),
    legend.text = element_text(size = 15)
  )

# Top panel: logged fatalities per event against the weighted reference series.
# Line thickness is set with `linewidth`; `size` is deprecated for line geoms.
# NOTE(review): the reference line uses the raw `cumulated_fatal_weighted`
# column, not the log1p version computed on df_plot -- confirm the raw series
# is already on the log scale the axis label promises.
p_top <- ggplot(df_plot, aes(x = startdate)) +
  geom_line(aes(y = ndeath_log, color = "Daily Fatalities"), linewidth = 1.0) +
  geom_line(
    aes(y = cumulated_fatal_weighted, color = "Reference Point"),
    linewidth = 1.0
  ) +
  labs(
    x = NULL,
    y = "Fatalities (logged)",
    color = NULL
  ) +
  scale_color_manual(values = c("Daily Fatalities" = "steelblue",
                                "Reference Point" = "firebrick")) +
  scale_y_continuous(expand = expansion(mult = c(0, 0.1))) +
  panel_theme

# -------------------------------
# Bottom panel: signed, weighted discrepancy with a dashed zero reference line.
p_bottom <- ggplot(df_plot, aes(x = startdate)) +
  geom_line(
    aes(y = fatality_diff_signed_weighted, color = "Discrepancy"),
    linewidth = 1.0
  ) +
  geom_hline(yintercept = 0, linetype = "dashed", color = "gray40") +
  labs(
    x = "Event Date",
    y = "Discrepancy (logged)",
    color = NULL
  ) +
  scale_color_manual(values = c("Discrepancy" = "darkgreen")) +
  panel_theme

# -------------------------------
# Stack the panels; the top one gets twice the height of the bottom one.
combined_plot <- p_top / p_bottom + plot_layout(heights = c(2, 1))

print(combined_plot)

# -------------------------------
# Fit a simple linear regression of column `y` on column `x` of `data`.
#
# Arguments:
#   y, x   Column names (strings) of the response and the predictor.
#   data   Data frame holding both columns.
#   min_n  Minimum number of usable rows required to fit the model.
#
# Returns a list with:
#   fit  The `lm` object, or NULL when fewer than `min_n` rows are usable.
#   ok   Logical vector, one entry per row of `data`, flagging usable rows.
#   n    Number of usable rows.
#
# A row is usable when both values are present; for numeric columns they must
# also be finite, because lm() aborts on Inf/-Inf.
make_fit <- function(y, x, data, min_n = 10) {
  missing_cols <- setdiff(c(y, x), names(data))
  if (length(missing_cols) > 0) {
    stop(
      "Column(s) not found in `data`: ",
      paste(missing_cols, collapse = ", "),
      call. = FALSE
    )
  }

  usable <- function(v) {
    if (is.numeric(v)) is.finite(v) else !is.na(v)
  }

  ok <- usable(data[[y]]) & usable(data[[x]])
  n_ok <- sum(ok)

  fit <- NULL
  if (n_ok >= min_n) {
    fit <- lm(reformulate(x, y), data = data[ok, , drop = FALSE])
  }

  list(fit = fit, ok = ok, n = n_ok)
}

# Root-mean-squared error between observed values `y` and predictions `yhat`.
# Pairs in which either value is NA are left out of the mean.
rmse <- function(y, yhat) {
  squared_error <- (y - yhat)^2
  sqrt(mean(squared_error, na.rm = TRUE))
}

# Compare two single-predictor linear models of mobilisation for one country.
#
# `participants_max` is regressed separately on `ndeath_log` (absolute
# measure) and on `fatality_diff_signed_weighted` (discrepancy measure). Each
# model is fitted on its own usable rows; both RMSEs are then computed on the
# rows usable for *both* models, so the comparison is made on one sample.
#
# Arguments:
#   dat        Data frame with columns `startdate`, `participants_max`,
#              `ndeath_log` and `fatality_diff_signed_weighted`.
#   min_train  Minimum usable rows needed to fit each model.
#   min_eval   Minimum shared rows needed to report the RMSEs.
#
# Returns a one-row tibble with n_total, n_abs, n_disc, n_eval, RMSE_abs,
# RMSE_disc and dRMSE (= RMSE_disc - RMSE_abs, so negative values favour the
# discrepancy measure). The RMSE columns are NA when either model could not
# be fitted or fewer than `min_eval` shared rows exist. A leading `ccode`
# column is included only when `dat` itself carries a `ccode` column.
fit_metrics_one_ccode_fair <- function(dat, min_train = 10, min_eval = 10) {
  d <- dat %>%
    mutate(
      startdate = as.Date(startdate),
      mobil     = participants_max,
      abs_rep   = ndeath_log,
      disc_lin  = fatality_diff_signed_weighted
    ) %>%
    # Sort after parsing so the order is chronological even for text dates.
    arrange(startdate)

  res_abs  <- make_fit("mobil", "abs_rep",  d, min_n = min_train)
  res_disc <- make_fit("mobil", "disc_lin", d, min_n = min_train)

  # Rows on which both models can be evaluated.
  eval_idx <- res_abs$ok & res_disc$ok
  n_eval   <- sum(eval_idx)

  RMSE_abs <- RMSE_disc <- NA_real_
  both_fitted <- !is.null(res_abs$fit) && !is.null(res_disc$fit)

  if (both_fitted && n_eval >= min_eval) {
    eval_rows <- d[eval_idx, ]
    y_eval    <- eval_rows$mobil
    RMSE_abs  <- rmse(y_eval, predict(res_abs$fit,  newdata = eval_rows))
    RMSE_disc <- rmse(y_eval, predict(res_disc$fit, newdata = eval_rows))
  }

  # Under group_modify() the grouping column is not part of the data handed to
  # this function (it is re-attached afterwards). Look it up only when it is
  # present so no "unknown column" warning is raised; a NULL entry is simply
  # dropped by tibble().
  ccode_value <- if ("ccode" %in% names(d)) unique(d[["ccode"]])[1] else NULL

  tibble(
    ccode     = ccode_value,
    n_total   = nrow(d),
    n_abs     = res_abs$n,
    n_disc    = res_disc$n,
    n_eval    = n_eval,
    RMSE_abs  = RMSE_abs,
    RMSE_disc = RMSE_disc,
    dRMSE     = RMSE_disc - RMSE_abs
  )
}

# Per-country model comparison; countries without a usable dRMSE are dropped
# once here, so the counts below need no further NA handling.
by_ccode <- df %>%
  group_by(ccode) %>%
  group_modify(~ fit_metrics_one_ccode_fair(.x, min_train = 10, min_eval = 10)) %>%
  ungroup() %>%
  filter(!is.na(dRMSE))

# Same filtered table under its second name.
stats <- by_ccode

n_valid <- nrow(by_ccode)
if (n_valid == 0) {
  # Without this guard the shares below would be 0 / 0.
  stop(
    "No country has enough complete observations to compare RMSEs.",
    call. = FALSE
  )
}

n_abs  <- sum(by_ccode$dRMSE > 0)   # absolute measure has the lower RMSE
n_disc <- sum(by_ccode$dRMSE < 0)   # discrepancy measure has the lower RMSE
n_tie  <- sum(by_ccode$dRMSE == 0)

subtitle_text <- paste0(
  "Negative Values → Discrepancy Measure Has Lower RMSE\n",
  "Absolute Measure Better: ",  n_abs,  " (", percent(n_abs/n_valid,  accuracy = 0.1), ")",
  "\nDiscrepancy Measure Better: ", n_disc, " (", percent(n_disc/n_valid, accuracy = 0.1), ")"
)

# Plot data: label each country by which measure has the lower RMSE and attach
# a readable country name.
plot_df <- by_ccode %>%
  mutate(
    cat = case_when(
      is.na(dRMSE)     ~ NA_character_,
      dRMSE < 0        ~ "Discrepancy Measure Better",
      dRMSE > 0        ~ "Absolute Measure Better",
      TRUE             ~ "Tie"
    ),
    # Fall back to the code itself when countrycode() has no name for it, so
    # unmatched countries keep their own labelled bar instead of sharing NA.
    country = coalesce(
      countrycode(ccode, origin = "cown", destination = "country.name"),
      as.character(ccode)
    )
  )

# Horizontal bar chart of the per-country RMSE difference, ordered by dRMSE.
# NOTE(review): with coord_flip() the visible label is styled by
# `axis.title.x`; the "Times New Roman" family set on `axis.title.y` applies
# to the (NULL) country-axis title and so presumably has no visible effect --
# confirm which axis was meant.
p_rmse <- ggplot(plot_df, aes(x = reorder(country, dRMSE), y = dRMSE, fill = cat)) +
  geom_col() +
  coord_flip() +
  geom_hline(yintercept = 0, linetype = "dashed", color = "gray50") +
  scale_fill_manual(
    values = c(
      "Absolute Measure Better"    = "steelblue",
      "Discrepancy Measure Better" = "darkgreen",
      "Tie"          = "gray60"
    ),
    na.translate = FALSE, name = NULL
  ) +
  labs(
    x = NULL,
    y = "ΔRMSE (Discrepancy − Absolute)",
    #title = "SCAD — RMSE Difference by Country",
    subtitle = subtitle_text
  ) +
  theme_minimal() +
  theme(legend.position = "top", legend.justification = "left",
        text = element_text(size = 14),
        plot.title = element_text(face = "bold", size = 16, hjust = 0.5),
        axis.title.x = element_text(size = 15, margin = margin(t = 15)),
        axis.title.y = element_text(size = 15, margin = margin(r = 15), family = "Times New Roman"),
        axis.text = element_text(size = 15),
        legend.text = element_text(size = 15))

print(p_rmse)

# Pass the plot explicitly rather than relying on whichever plot happens to be
# the most recent one.
ggsave("RMSE.jpeg", plot = p_rmse, width = 10, height = 14, dpi = 300, bg = "white")

